from pyspark.ml.classification import DecisionTreeClassifier, DecisionTreeClassificationModel
from utils.spark_util import col_validation, one_hot_encoding, feature_cols_labelling, label_col_labelling, \
    get_class_model_info, one_hot_encoding_p, cast_to_float, col_name_check, col_decoder, get_distincts, col_encoder, \
    train_test_split, multiclass_eval, get_confusion_matrix

# A confusion matrix is only computed by train() when the number of distinct
# labels is strictly below this value.
MAX_CAT_CM_TOLERANCE = 10
# train() refuses to fit when the label column has more distinct values than
# this. NOTE(review): the name is misspelled ("TOLERNACE"); it is kept as-is
# because other modules may reference it - confirm before renaming.
MAX_CAT_LABEL_TOLERNACE = 1500
import logging

# Configures the root logger at import time: INFO level, with timestamp,
# logger name, level and message in every record.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")


def predict(sc, dataframe, para_list, output):
    """Apply a saved decision-tree classification model to ``dataframe``.

    Args:
        sc: Handle passed through to the ``get_class_model_info`` and
            ``one_hot_encoding_p`` helpers (presumably the Spark context).
        dataframe: Input DataFrame containing the feature columns.
        para_list: Parameter mapping; reads ``"feature_col"``, ``"model_id"``
            and ``"source"``.
        output: Result mapping, mutated in place. Encoding messages are
            appended to ``output["message"]`` and the returned column names
            are written to ``output["result"]["output_params"]["output_cols"]``.

    Returns:
        The input DataFrame plus one decoded prediction column; the assembled
        feature vector, raw model outputs and one-hot helper columns are
        dropped.

    Raises:
        IndexError: If the stored training feature list contains more
            one-hot encoded features than there are categorical input columns.
    """
    # --------------------prepare feature data------------

    num_cols, cat_cols = col_validation(dataframe, para_list["feature_col"])
    model_save_path, feature_list_train, label_list = get_class_model_info(sc, para_list["model_id"])

    # Rebuild, for every categorical input column, the categories recorded at
    # training time. Training features are matched to input columns by order
    # of first appearance, so the input columns may be named differently from
    # the training ones. The lookup is keyed on the training feature name;
    # keying it on the input column name (as before) broke that matching.
    train_to_input = {}
    cat_col_dict = {}
    for col in feature_list_train:
        if "_OneHot/" not in col:
            continue
        # maxsplit=1 keeps category values that themselves contain the marker.
        feature, category = col.split("_OneHot/", 1)
        if feature not in train_to_input:
            train_to_input[feature] = cat_cols[len(train_to_input)]
        cat_col_dict.setdefault(train_to_input[feature], []).append(category)

    if cat_cols:
        dataframe, message, encoded_cat_cols = one_hot_encoding_p(
            sc, dataframe, para_list["source"], cat_cols, cat_col_dict)
        if len(message) != 0:
            output["message"].append(message)
    else:
        encoded_cat_cols = []

    feature_list = num_cols + encoded_cat_cols
    df = feature_cols_labelling(dataframe, feature_list)

    # --------------------load the model-----------------

    model = DecisionTreeClassificationModel.load(model_save_path)

    # --------------------make predictions----------------

    df_pred = model.transform(df)
    df_pred = cast_to_float(df_pred, "prediction") \
        .drop("features", "label", "rawPrediction", "probability")

    # Ask the helper for a name for the prediction column, given the existing
    # columns (presumably to avoid a clash - confirm in col_name_check).
    new_col_name = col_name_check("_classification_", df_pred.columns)
    df_pred = df_pred.withColumnRenamed("prediction", new_col_name)

    # Convert the prediction column back using the stored label list.
    df_pred = col_decoder(df_pred, new_col_name, label_list)

    output["result"]["output_params"]["output_cols"] = list(para_list["feature_col"]) + [new_col_name]

    # The one-hot helper columns are an implementation detail; remove them.
    if encoded_cat_cols:
        df_pred = df_pred.drop(*encoded_cat_cols)
    df_pred.show(4)
    return df_pred


def train(sc, dataframe, para_list, record, output):
    """Fit a decision-tree classifier and record its test-set metrics.

    Args:
        sc: Handle passed through to the ``get_distincts`` and
            ``one_hot_encoding`` helpers (presumably the Spark context).
        dataframe: Input DataFrame containing feature and label columns.
        para_list: Parameter mapping; reads ``"source"``, ``"label_col"``,
            ``"feature_col"``, ``"split_rate"`` and ``"max_depth"``.
        record: Mapping mutated in place; ``record["other_info"]`` receives
            the feature list, label list, feature importances, metrics and
            the stringified confusion matrix.
        output: Mapping mutated in place; problems are appended to
            ``output["error"]`` and encoding messages to ``output["message"]``.

    Returns:
        The fitted model, or ``None`` when the label column has more than
        ``MAX_CAT_LABEL_TOLERNACE`` distinct values (an error is appended to
        ``output["error"]`` in that case).
    """
    # --------------------prepare data------------------
    label_list = get_distincts(sc, para_list["source"], para_list["label_col"])
    if len(label_list) > MAX_CAT_LABEL_TOLERNACE:
        output["error"].append(
            "Error: Label column has more than %d categories." % MAX_CAT_LABEL_TOLERNACE)
        # Return a single None so the failure path has the same shape as the
        # success path; the former (None, None) tuple is truthy and is not
        # None, so callers could not detect the failure by checking the result.
        return None

    dataframe = col_encoder(dataframe, para_list["label_col"], label_list)

    num_cols, cat_cols = col_validation(dataframe, para_list["feature_col"])
    print("COLS")
    print(num_cols, cat_cols)
    if cat_cols:
        dataframe, message, encoded_cat_cols = one_hot_encoding(sc, dataframe, para_list["source"], cat_cols)
        if len(message) != 0:
            output["message"].append(message)
    else:
        encoded_cat_cols = []

    feature_list = num_cols + encoded_cat_cols

    df = feature_cols_labelling(dataframe, feature_list)
    # The label column is read under the "<label_col>_encoded" name.
    df = label_col_labelling(df, para_list["label_col"] + "_encoded")

    # Local names avoid shadowing this function's own name.
    train_df, test_df = train_test_split(df, para_list["split_rate"])

    # ---------------------model fit---------------------
    print("DTC MODEL FITTING")
    dt = DecisionTreeClassifier(labelCol="label", featuresCol="features",
                                maxDepth=para_list["max_depth"])

    model = dt.fit(train_df)
    print("MODEL FITTED")
    # ---------------------model validation---------------

    test_pred_raw = model.transform(test_df)

    # Decoded copy of the predictions, used only for the confusion matrix.
    test_pred_real = col_decoder(test_pred_raw, "prediction", label_list)

    # ---------------------get metrics--------------------

    accuracy_test, precision_test, recall_test, F1_test = multiclass_eval(test_pred_raw)
    print("CALCULATING CM")
    # Skip the confusion matrix once the label count reaches the limit.
    if len(label_list) < MAX_CAT_CM_TOLERANCE:
        confusion_matrix = get_confusion_matrix(label_list, test_pred_real, para_list)
    else:
        confusion_matrix = []

    print("METRICS GET")
    # ---------------------record info----------------------
    record["other_info"] = {
        "feature_list": feature_list,
        "feature_col": para_list["feature_col"],
        "label_list": label_list,
        "feature_importances": list(model.featureImportances),
        "accuracy_test": accuracy_test,
        "Fmeasure_test": F1_test,
        "precision_test": precision_test,
        "recall_test": recall_test,
        "confusion_matrix": str(confusion_matrix),
    }
    print("RETURN")
    return model
